package leetcode.code1325;

import leetcode.helper.tree.TreeNode;

public class Solution {
	/**
	 * Deletes every leaf whose value equals {@code target}, cascading upward.
	 *
	 * <p>The traversal is post-order: both subtrees are pruned before their parent
	 * is examined. A parent that is left with no children after pruning is
	 * therefore treated as a leaf and removed too, if its value matches. The
	 * recursion depth is proportional to the height of the tree.
	 *
	 * <p>Only surviving nodes have their child links rewritten. A node that is
	 * removed keeps whatever {@code left}/{@code right} references it had.
	 *
	 * @param root   root of the tree to prune; may be {@code null}
	 * @param target value that marks a leaf for deletion
	 * @return the root of the pruned tree, or {@code null} if the tree was empty
	 *         or every node was removed
	 */
	public TreeNode removeLeafNodes(TreeNode root, int target) {
		if (root != null) {
			// Prune both subtrees first so this node sees its post-deletion shape.
			TreeNode prunedLeft = removeLeafNodes(root.left, target);
			TreeNode prunedRight = removeLeafNodes(root.right, target);
			boolean isNowLeaf = prunedLeft == null && prunedRight == null;
			if (!isNowLeaf || root.val != target) {
				// This node survives: attach the pruned subtrees and keep it.
				root.left = prunedLeft;
				root.right = prunedRight;
				return root;
			}
		}
		// Either the tree was empty or this node became a deletable leaf.
		return null;
	}
}
